# Task definition that installs a patched copy of the Lightning tutorials
# repository and runs its basic-gan example script (gan.py).
name: mnist-gan

# Accelerator options for this task. Each list entry names one accelerator
# type; presumably every candidate is run separately so results can be
# compared (the workdir path suggests a benchmark) -- confirm with the
# consuming tool.
resources:
  candidates:
  - accelerators: T4
  - accelerators: V100

# Local working directory for this task. The setup script applies
# `../callback.patch` from inside the cloned `tutorials/` directory, so
# callback.patch is presumably expected to live in this directory -- confirm.
workdir: ./examples/benchmark/lightning_gan

# Environment preparation: creates the `pl` conda environment, installs
# SkyCallback and the training dependencies, then clones the Lightning
# tutorials repository at a pinned commit and patches it.
setup: |
  # Abort on the first failing command so a broken install, clone or
  # checkout is not masked by the steps that follow it.
  set -e

  conda create -n pl python=3.8 -y
  conda activate pl

  # Install SkyCallback
  pip install "git+https://github.com/skypilot-org/skypilot.git#egg=sky-callback&subdirectory=sky/callbacks/"

  # User setup
  pip install "torchvision" "pytorch-lightning>=1.4" "torch>=1.6, <1.9"

  # Clone only when no checkout exists yet, so re-running setup on the same
  # machine does not fail on the already-present directory.
  if [ ! -d tutorials/.git ]; then
    git clone https://github.com/Lightning-AI/tutorials.git
  fi
  cd tutorials
  # Pin the repository to the commit the patch below is applied on top of.
  git checkout e22e229921a97ea241277e19e0eaddedc35808cb

  # Apply the patch to enable SkyCallback. The reverse check succeeds only
  # when the patch is already present, in which case applying is skipped.
  if ! git apply --reverse --check ../callback.patch 2>/dev/null; then
    git apply ../callback.patch
  fi

# Job command: runs the basic-gan example from the checkout that the setup
# script cloned and patched.
run: |
  # Stop immediately if the environment cannot be activated or the example
  # directory is missing, instead of running gan.py with the wrong
  # interpreter or from the wrong directory.
  set -e

  conda activate pl
  cd tutorials/lightning_examples/basic-gan/
  python gan.py
